# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import os
import shutil
import tempfile
import unittest

import numpy as np
import paddle
import paddle.device.cuda.graphs as graphs

from fastdeploy.config import (
    CacheConfig,
    CommitConfig,
    DeviceConfig,
    EarlyStopConfig,
    FDConfig,
    GraphOptimizationConfig,
    LoadConfig,
    ModelConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
)
from fastdeploy.model_executor.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.layers.attention import (
    AttentionBackend,
    get_attention_backend,
)
from fastdeploy.model_executor.layers.attention.append_attn_backend import (
    allocate_launch_related_buffer,
)
from fastdeploy.model_executor.layers.quantization.mix_quant import MixQuantConfig
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_Attention
from fastdeploy.model_executor.ops.gpu import get_padding_offset

# (ZKK): CI machines report the generic device name "NVIDIA Graphics Device".
# On those machines, default DG_NVCC_OVERRIDE_CPP_STANDARD to "17"; setdefault
# leaves any value already present in the environment untouched.
_ON_CI_MACHINE = "nvidia graphics device" in paddle.device.cuda.get_device_name().lower()
if _ON_CI_MACHINE:
    os.environ.setdefault("DG_NVCC_OVERRIDE_CPP_STANDARD", "17")


class TestAttentionPerformance(unittest.TestCase):
    """Decode-phase attention latency benchmark for an Ernie4.5-MoE-shaped config.

    Builds ``num_hidden_layers`` ``Ernie4_5_Attention`` layers with random weights
    (block-wise FP8 dense/MoE quantization, FP8 KV-cache quantization), fabricates
    paged KV caches for several decode batch sizes, captures one forward pass through
    all layers into a CUDA graph, and prints replay latencies measured with CUDA
    events. This is a benchmark: it makes no correctness assertions.
    """

    def setUp(self):
        """
        Set up the testing environment before each test.

        Writes a temporary ``config.json``, builds an ``FDConfig`` from it (TP=1),
        creates the attention backend, instantiates one ``Ernie4_5_Attention`` per
        hidden layer loaded with random weights, and exposes ``self.attn_forward``,
        which runs hidden states through every layer in order.
        """
        print("Setting up test environment...")
        paddle.set_device("gpu")
        # set_default_dtype changes process-global state; restore the previous value
        # after the test so other tests in the same process are unaffected.
        self.addCleanup(paddle.set_default_dtype, paddle.get_default_dtype())
        paddle.set_default_dtype("bfloat16")

        self.model_dir = self.create_model_config_json()
        # tearDown() is not run when setUp() raises, so register a cleanup that always
        # removes the temporary directory. ignore_errors=True makes it a no-op once
        # tearDown() has already deleted the directory.
        self.addCleanup(shutil.rmtree, self.model_dir, ignore_errors=True)

        self.fd_config = self.create_fd_config_from_model_path(self.model_dir, tensor_parallel_size=1)
        self.fd_config.parallel_config.tp_group = [0]

        model_config = self.fd_config.model_config
        tp_size = self.fd_config.parallel_config.tensor_parallel_size

        # Initialize the attention backend with per-tensor-parallel-rank head counts.
        attn_cls = get_attention_backend()
        self.attn_backend = attn_cls(
            self.fd_config,
            kv_num_heads=model_config.num_key_value_heads // tp_size,
            num_heads=model_config.num_attention_heads // tp_size,
            head_dim=model_config.head_dim,
            encoder_block_shape_q=64,
            decoder_block_shape_q=16,
        )

        num_layers = model_config.num_hidden_layers
        self.attention_layer = []
        for layer_id in range(num_layers):
            layer = Ernie4_5_Attention(self.fd_config, layer_id=layer_id, prefix="test_layer")
            # All layers share the "test_layer" prefix; each gets its own random state dict.
            layer.load_state_dict(self.create_random_attention_state_dict(self.fd_config, prefix="test_layer"))
            self.attention_layer.append(layer)

        def attn_forward(forward_meta, hidden_states):
            """Run ``hidden_states`` through all attention layers sequentially."""
            for layer in self.attention_layer:
                hidden_states = layer(forward_meta, hidden_states)
            return hidden_states

        self.attn_forward = attn_forward

        # Falls back to "none" when the attention module exposes no cache_quant_type_str.
        self.cache_quant_type_str = getattr(self.attention_layer[0].attn, "cache_quant_type_str", "none")

        print("===== Initialization Complete =====")

    def tearDown(self):
        """
        Clean up the environment after each test by removing the temporary model directory.
        """
        print("\nTearing down test environment...")
        if os.path.exists(self.model_dir):
            shutil.rmtree(self.model_dir)
            print(f"Successfully removed temporary directory: {self.model_dir}")

    # region Helper Functions
    def create_model_config_json(self) -> str:
        """
        Creates a temporary directory and writes the model configuration to a 'config.json' file.

        Returns:
            Path of the newly created temporary directory (the caller removes it).
        """
        config_dict = {
            "architectures": ["Ernie4_5_MoeForCausalLM"],
            "dtype": "bfloat16",
            "max_position_embeddings": 131072,
            "max_model_len": 131072,
            "head_dim": 128,
            "hidden_size": 8192,
            "num_attention_heads": 64,
            "num_key_value_heads": 8,
            "num_hidden_layers": 2,
        }
        model_dir = tempfile.mkdtemp(prefix="tmp_model_config_")
        config_path = os.path.join(model_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=4)
        print(f"Successfully created config.json at: {config_path}")
        return model_dir

    def create_fd_config_from_model_path(self, model_path: str, tensor_parallel_size: int = 1) -> FDConfig:
        """
        Creates a complete FDConfig from a model path.

        Uses block_size=64 for the KV cache and a MixQuantConfig with block-wise FP8
        dense/MoE weights and an FP8 (float8_e4m3fn) KV cache. All other sub-configs
        are built from empty argument dicts.
        """
        model_args = {"model": model_path, "dtype": "bfloat16"}
        model_config = ModelConfig(model_args)
        model_config.tensor_parallel_size = tensor_parallel_size
        parallel_config = ParallelConfig({"tensor_parallel_size": tensor_parallel_size, "data_parallel_size": 1})
        cache_config = CacheConfig(
            {
                "block_size": 64,
                "model_cfg": model_config,
                "tensor_parallel_size": tensor_parallel_size,
            }
        )
        return FDConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=SchedulerConfig({}),
            load_config=LoadConfig({}),
            quant_config=MixQuantConfig(
                dense_quant_type="block_wise_fp8",
                moe_quant_type="block_wise_fp8",
                kv_cache_quant_type="float8_e4m3fn",
                # Swap in the line below to benchmark an unquantized KV cache.
                # kv_cache_quant_type=None,
            ),
            graph_opt_config=GraphOptimizationConfig({}),
            commit_config=CommitConfig(),
            device_config=DeviceConfig({}),
            speculative_config=SpeculativeConfig({}),
            early_stop_config=EarlyStopConfig({}),
        )

    def create_random_attention_state_dict(self, fd_config: FDConfig, prefix: str) -> dict:
        """
        Creates a state_dict with random weights for the Ernie4_5_Attention layer.

        Contains:
            ``{prefix}.qkv_proj.weight``: randn, ``[hidden_size, (q_dims + 2 * kv_dims) // tp]``.
            ``{prefix}.o_proj.weight``: randn, ``[q_dims // tp, hidden_size]``.
            ``{prefix}.cachek_matmul.activation_scale`` and
            ``{prefix}.cachev_matmul.activation_scale``: the same all-ones tensor of
            shape ``[num_key_value_heads // tp]``.
        All tensors use the model dtype.
        """
        hidden_size = fd_config.model_config.hidden_size
        tp_size = fd_config.parallel_config.tensor_parallel_size
        tensor_dtype = getattr(paddle, fd_config.model_config.dtype)

        q_dims = fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim
        kv_dims = fd_config.model_config.num_key_value_heads * fd_config.model_config.head_dim
        total_output_dim = q_dims + 2 * kv_dims
        qkv_proj_output_dim_tp = total_output_dim // tp_size
        qkv_weight_shape = [hidden_size, qkv_proj_output_dim_tp]

        o_proj_input_dim = fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim
        o_proj_input_dim_tp = o_proj_input_dim // tp_size
        o_proj_weight_shape = [o_proj_input_dim_tp, hidden_size]

        qkv_weight = paddle.randn(qkv_weight_shape, dtype=tensor_dtype)
        o_proj_weight = paddle.randn(o_proj_weight_shape, dtype=tensor_dtype)

        kv_num_heads_tp = fd_config.model_config.num_key_value_heads // tp_size
        activation_scale_shape = [kv_num_heads_tp]
        activation_scale_tensor = paddle.full(shape=activation_scale_shape, fill_value=1.0, dtype=tensor_dtype)

        state_dict = {
            f"{prefix}.qkv_proj.weight": qkv_weight,
            f"{prefix}.o_proj.weight": o_proj_weight,
            f"{prefix}.cachek_matmul.activation_scale": activation_scale_tensor,
            f"{prefix}.cachev_matmul.activation_scale": activation_scale_tensor,
        }
        return state_dict

    def create_forward_meta(
        self,
        batch_size: int,
        seq_len: int,
        mode: ForwardMode,
        fd_config: FDConfig,
        attn_backend: AttentionBackend,
        cache_quant_type_str: str = "none",
    ) -> tuple[ForwardMeta, paddle.Tensor]:
        """
        Creates a high-fidelity ForwardMeta object plus matching random hidden states.

        Args:
            batch_size: Number of sequences in the batch.
            seq_len: For EXTEND, the prompt length of every sequence. For DECODE, the
                number of tokens already in each sequence's KV cache; one new token
                per sequence is processed this step.
            mode: ForwardMode.EXTEND or ForwardMode.DECODE. This only selects the
                sequence-length tensors; the returned ForwardMeta is always tagged
                ForwardMode.MIXED.
            fd_config: Supplies model shapes, KV block size and dtype.
            attn_backend: Attention backend stored on the ForwardMeta.
            cache_quant_type_str: KV-cache quantization type. Any value other than
                "none" allocates uint8 caches; "block_wise_fp8" also allocates
                per-token scale tensors for every layer.

        Returns:
            ``(forward_meta, hidden_states)``. ``hidden_states`` has shape
            ``[sum(seq_lens_this_time), hidden_size]`` and the model dtype.

        Raises:
            ValueError: If ``mode`` is neither EXTEND nor DECODE.
        """
        model_config = fd_config.model_config

        if mode == ForwardMode.EXTEND:
            seq_lens_encoder = paddle.full([batch_size], seq_len, dtype="int32")
            seq_lens_decoder = paddle.zeros([batch_size], dtype="int32")
            seq_lens_this_time = seq_lens_encoder
        elif mode == ForwardMode.DECODE:
            seq_lens_encoder = paddle.zeros([batch_size], dtype="int32")
            seq_lens_decoder = paddle.full([batch_size], seq_len, dtype="int32")
            seq_lens_this_time = paddle.ones([batch_size], dtype="int32")
        else:
            raise ValueError(f"Unsupported ForwardMode: {mode}")

        block_size = fd_config.cache_config.block_size
        max_model_len = model_config.max_model_len

        # NOTE(review): full (not per-TP) head counts are passed here, while the backend
        # in setUp and the caches below use per-TP counts. They are identical for TP=1;
        # confirm the expected convention before running with TP>1.
        attn_backend_buffers = allocate_launch_related_buffer(
            max_batch_size=batch_size,
            max_model_len=max_model_len,
            encoder_block_shape_q=64,
            decoder_block_shape_q=16,
            decoder_step_token_num=fd_config.speculative_config.num_speculative_tokens + 1,
            num_heads=model_config.num_attention_heads,
            kv_num_heads=model_config.num_key_value_heads,
            block_size=block_size,
        )

        # Paged KV cache: each sequence gets seq_len // block_size + 1 blocks, which is
        # always enough to hold seq_len + 1 tokens.
        max_blocks_per_seq = (max_model_len + block_size - 1) // block_size
        allocated_blocks_per_seq = seq_len // block_size + 1
        allocated_num_blocks = allocated_blocks_per_seq * batch_size
        head_dim = model_config.head_dim
        kv_num_heads_tp = model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size
        # Quantized caches are stored as raw uint8 bytes.
        cache_type = model_config.dtype if cache_quant_type_str == "none" else "uint8"
        cache_shape = (allocated_num_blocks, kv_num_heads_tp, block_size, head_dim)
        scale_shape = (allocated_num_blocks, kv_num_heads_tp, block_size)
        # Flat per-layer list: [k, v] or, for block_wise_fp8, [k, v, k_scale, v_scale].
        caches = []
        for _ in range(model_config.num_hidden_layers):
            key_cache = paddle.randint(0, 255, shape=cache_shape, dtype="int32").cast(cache_type)
            value_cache = paddle.randint(0, 255, shape=cache_shape, dtype="int32").cast(cache_type)
            caches.extend([key_cache, value_cache])
            if cache_quant_type_str == "block_wise_fp8":
                key_cache_scale = paddle.rand(shape=scale_shape, dtype=model_config.dtype)
                value_cache_scale = paddle.rand(shape=scale_shape, dtype=model_config.dtype)
                caches.extend([key_cache_scale, value_cache_scale])

        # Sequence i owns the contiguous physical blocks
        # [i * allocated_blocks_per_seq, (i + 1) * allocated_blocks_per_seq); unused
        # table slots stay 0. Filled in one vectorized assignment instead of
        # batch_size * allocated_blocks_per_seq individual device-tensor writes.
        block_tables = paddle.zeros(shape=(batch_size, max_blocks_per_seq), dtype="int32")
        block_tables[:, :allocated_blocks_per_seq] = paddle.arange(allocated_num_blocks, dtype="int32").reshape(
            [batch_size, allocated_blocks_per_seq]
        )

        tmp_position_ids = paddle.arange(max_model_len).reshape((1, -1))
        rope_emb = get_rope(
            rotary_dim=model_config.head_dim,
            position_ids=tmp_position_ids,
            base=model_config.rope_theta,
            model_config=model_config,
            partial_rotary_factor=model_config.partial_rotary_factor,
        )

        input_ids = paddle.zeros([batch_size, max_model_len], dtype="int64")
        token_num = paddle.sum(seq_lens_this_time)
        ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
            input_ids, token_num, seq_lens_this_time
        )

        forward_meta = ForwardMeta(
            ids_remove_padding=ids_remove_padding,
            seq_lens_encoder=seq_lens_encoder,
            seq_lens_decoder=seq_lens_decoder,
            seq_lens_this_time=seq_lens_this_time,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            batch_id_per_token=batch_id_per_token,
            block_tables=block_tables,
            caches=caches,
            rotary_embs=rope_emb,
            step_use_cudagraph=False,
            attn_backend=attn_backend,
            # Always MIXED regardless of `mode`; `mode` only shapes the seq-len tensors above.
            forward_mode=ForwardMode.MIXED,
            attn_mask=None,
            attn_mask_offsets=None,
            **attn_backend_buffers,
        )

        hidden_states = paddle.randn([token_num, model_config.hidden_size], dtype=model_config.dtype)
        return forward_meta, hidden_states

    def test_decode_performance_with_prefill(self):
        """Print CUDA-graph replay latencies of one decode step for several batch sizes."""
        # Test parameters
        test_steps = 100
        decode_seq_len = 36 * 1024

        # NOTE(review): the disabled prefill/profiling scaffold below is stale:
        # create_forward_meta now returns (forward_meta, hidden_states) and
        # act_tensor_dtype is undefined. Update it before re-enabling.

        # prefill_batch_size = 1
        # prefill_seq_len = 4096

        # prefill_hidden_states = paddle.randn(
        #     [prefill_batch_size * prefill_seq_len, self.fd_config.model_config.hidden_size],
        #     dtype=act_tensor_dtype,
        # )

        # forward_meta = self.create_forward_meta(
        #     batch_size=prefill_batch_size,
        #     seq_len=prefill_seq_len,
        #     mode=ForwardMode.EXTEND,
        #     fd_config=self.fd_config,
        #     attn_backend=self.attn_backend,
        #     cache_quant_type_str=self.cache_quant_type_str,
        # )

        # self.attn_backend.init_attention_metadata(forward_meta)
        # self.attn_forward(forward_meta, prefill_hidden_states)

        # paddle.device.synchronize()

        # import paddle.profiler as profiler
        # p = profiler.Profiler(
        #     targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
        #     on_trace_ready=profiler.export_chrome_tracing("./profile_log"),
        # )
        # p.start()
        # p.step()

        # start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
        # end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
        # for i in range(test_steps):
        #     start_events[i].record()

        #     self.attn_forward(forward_meta, prefill_hidden_states)

        #     end_events[i].record()
        # paddle.device.synchronize()

        # times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
        # print(times[-5:])
        # return

        # p.stop()

        # p = profiler.Profiler(
        #     targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
        #     on_trace_ready=profiler.export_chrome_tracing("./profile_log"),
        # )

        # p.start()
        # p.step()

        for decode_batch_size in [32, 16, 8, 4, 2]:
            forward_meta, hidden_states = self.create_forward_meta(
                batch_size=decode_batch_size,
                seq_len=decode_seq_len,
                mode=ForwardMode.DECODE,
                fd_config=self.fd_config,
                attn_backend=self.attn_backend,
                cache_quant_type_str=self.cache_quant_type_str,
            )

            self.attn_backend.init_attention_metadata(forward_meta)

            paddle.device.synchronize()

            # A warm-up pass is required before capture: the preprocessing has been
            # moved into the first layer's forward call.
            self.attn_forward(forward_meta, hidden_states)

            attn_cuda_graph = graphs.CUDAGraph()
            attn_cuda_graph.capture_begin()
            self.attn_forward(forward_meta, hidden_states)
            attn_cuda_graph.capture_end()

            start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
            end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_steps)]
            for start_event, end_event in zip(start_events, end_events):
                start_event.record()
                attn_cuda_graph.replay()
                end_event.record()
            paddle.device.synchronize()

            # The first replay is discarded and treated as warm-up.
            times = np.array([round(s.elapsed_time(e), 1) for s, e in zip(start_events, end_events)])[1:]
            print(f"decode_batch_size={decode_batch_size}: {times[-5:]}")

            # Release the captured graph (and the memory it holds) before building the
            # next, differently sized batch.
            attn_cuda_graph.reset()
            del forward_meta, hidden_states

        # p.stop()


if __name__ == "__main__":
    unittest.main()
